{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorlayer as tl\n",
    "import json\n",
    "import pylab as pl\n",
    "import random\n",
    "import numpy as np\n",
    "import os\n",
    "import cv2\n",
    "import anno_func\n",
    "%matplotlib inline\n",
    "from bbox import BoundingBox\n",
    "from errors import UnsupportedExtensionError, UnsupportedFormatError\n",
    "from PIL import Image\n",
    "from collections import Counter\n",
    "from tqdm import tqdm\n",
    "from PIL import ImageFile\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Root folder of the dataset; every other path in the notebook is built from it.\n",
    "datadir='../../data'\n",
    "filedir=os.path.join(datadir, 'annotations.json')\n",
    "# Parse the annotation file inside a context manager so the file handle is\n",
    "# closed as soon as the JSON has been read instead of being left open.\n",
    "with open(filedir) as annos_file:\n",
    "    annos=json.load(annos_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Count the annotated boxes of every category over all images in `annos`.\n",
    "class_counter=Counter(\n",
    "    box['category']\n",
    "    for img in annos['imgs'].values()\n",
    "    for box in img['objects']\n",
    ")\n",
    "# Keep the categories with at least 100 boxes, leaving out 'io', 'po' and 'wo'.\n",
    "# The list is sorted so that its order is the same on every run; a list built\n",
    "# from a set of strings can come out in a different order each time.\n",
    "classes=sorted(\n",
    "    category for category, count in class_counter.items()\n",
    "    if count>=100 and category not in ('io','po','wo')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Annotation container for the cropped images: 'imgs' starts empty and is\n",
    "# filled by crop_image, 'types' is taken from anno_func.type42.\n",
    "new_annos={'imgs':{}, 'types':anno_func.type42}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def transform_box(box, new_xmin, new_ymin):\n",
    "    \"\"\"Return a new box dict whose bbox is expressed relative to a crop origin.\n",
    "\n",
    "    `new_xmin` is subtracted from both x coordinates and `new_ymin` from both\n",
    "    y coordinates of `box['bbox']`; `box['category']` is carried over as is.\n",
    "    The input dict is left unmodified.\n",
    "    \"\"\"\n",
    "    source=box['bbox']\n",
    "    # Offset that applies to each coordinate key.\n",
    "    offsets={'xmin':new_xmin, 'xmax':new_xmin, 'ymin':new_ymin, 'ymax':new_ymin}\n",
    "    shifted={key:source[key]-offsets[key] for key in ('xmin','xmax','ymin','ymax')}\n",
    "    return {'bbox':shifted, 'category':box['category']}\n",
    "\n",
    "def overlap(box, old_width, old_height,new_xmin, new_ymin, new_width, new_height,tol):\n",
    "    \"\"\"Tell whether `box` lies inside a crop window, allowing `tol` pixels of slack.\n",
    "\n",
    "    The box is first clipped to the image extent (0..old_width, 0..old_height).\n",
    "    The window starts at (new_xmin, new_ymin) and is new_width x new_height in\n",
    "    size. Returns True only when every side of the clipped box is within `tol`\n",
    "    of the window.\n",
    "    \"\"\"\n",
    "    bbox=box['bbox']\n",
    "    left, top=max(bbox['xmin'],0), max(bbox['ymin'],0)\n",
    "    right, bottom=min(bbox['xmax'],old_width), min(bbox['ymax'],old_height)\n",
    "    fits_horizontally=left>=new_xmin-tol and right-tol<=new_xmin+new_width\n",
    "    fits_vertically=top>=new_ymin-tol and bottom-tol<=new_ymin+new_height\n",
    "    return fits_horizontally and fits_vertically\n",
    "    \n",
    "def crop_image(id,old_width=2048, old_height=2048, new_width=512, new_height=512,note='_1',cropped_dir='cropped_train',tol=8,aug=False,max_box_size=128):\n",
    "    \"\"\"Save one crop per eligible object of image `id` and register it in `new_annos`.\n",
    "\n",
    "    An object is eligible when its category is in `classes` and neither its\n",
    "    width nor its height exceeds `max_box_size`. For each eligible object a\n",
    "    new_width x new_height window is placed at a random position chosen so that\n",
    "    it covers the object (up to integer truncation), moved back inside the\n",
    "    old_width x old_height image when it would pass the right or bottom edge,\n",
    "    written to datadir/cropped_dir/<new_id>.jpg and recorded under\n",
    "    new_annos['imgs'][new_id], where new_id is id + '_' + index + note.\n",
    "\n",
    "    Parameters:\n",
    "        id: key of the source image in annos['imgs'].\n",
    "        old_width, old_height: size of the source image in pixels.\n",
    "        new_width, new_height: size of every crop in pixels.\n",
    "        note: suffix appended to the crop id; crops with the same id overwrite\n",
    "            each other.\n",
    "        cropped_dir: folder, relative to datadir, that receives the crops.\n",
    "        tol: slack in pixels forwarded to overlap() when deciding which objects\n",
    "            belong to a crop.\n",
    "        aug: when true, the crop goes through tl.prepro.illumination (called with\n",
    "            gamma, contrast and saturation ranges of (0.5, 1.5) and\n",
    "            is_random=True) before it is saved.\n",
    "        max_box_size: largest width/height an object may have to get a crop of\n",
    "            its own. It should stay below the crop size, otherwise the random\n",
    "            placement range can be empty and np.random.randint raises ValueError.\n",
    "\n",
    "    Returns None. Side effects: image files are written and new_annos['imgs']\n",
    "    is updated.\n",
    "    \"\"\"\n",
    "    record=annos['imgs'][id]\n",
    "    # Membership is tested for every object, so use a set instead of scanning the list.\n",
    "    wanted=set(classes)\n",
    "    labelled=[obj for obj in record['objects'] if obj['category'] in wanted]\n",
    "    # Objects larger than max_box_size do not get a crop of their own, but they can\n",
    "    # still be listed among the contents of a crop built around another object.\n",
    "    anchors=[obj for obj in labelled\n",
    "             if not (obj['bbox']['xmax']-obj['bbox']['xmin']>max_box_size\n",
    "                     or obj['bbox']['ymax']-obj['bbox']['ymin']>max_box_size)]\n",
    "    if not anchors:\n",
    "        # Nothing to crop, so the image file is not opened at all.\n",
    "        return\n",
    "\n",
    "    # Open the source image once for all crops and close it when done.\n",
    "    with Image.open(os.path.join(datadir, record['path'])) as image:\n",
    "        for ids, box in enumerate(anchors):\n",
    "            xmin=box['bbox']['xmin']\n",
    "            ymin=box['bbox']['ymin']\n",
    "            xmax=box['bbox']['xmax']\n",
    "            ymax=box['bbox']['ymax']\n",
    "            new_id=id+'_'+str(ids)+note\n",
    "\n",
    "            # Random window origin: far enough right/down to reach the object's\n",
    "            # max edge, and not beyond the object's min edge.\n",
    "            new_xmin=np.random.randint(int(max(0,xmax-new_width)), max(0,int(xmin))+1)\n",
    "            new_ymin=np.random.randint(int(max(0,ymax-new_height)),max(0,int(ymin))+1)\n",
    "\n",
    "            # Keep the window inside the image: clamp the far edges and derive\n",
    "            # the origin from them so the crop keeps its full size.\n",
    "            right=min(new_xmin+new_width, old_width)\n",
    "            bottom=min(new_ymin+new_height, old_height)\n",
    "            new_xmin=right-new_width\n",
    "            new_ymin=bottom-new_height\n",
    "\n",
    "            new_img=np.asarray(image.crop((new_xmin, new_ymin, right,bottom)))\n",
    "            if aug:\n",
    "                new_img=tl.prepro.illumination(new_img, gamma=(0.5, 1.5),\n",
    "                    contrast=(0.5, 1.5), saturation=(0.5, 1.5), is_random=True)\n",
    "\n",
    "            new_img_path=new_id+'.jpg'\n",
    "            tl.visualize.save_image(new_img, os.path.join(datadir, cropped_dir, new_img_path))\n",
    "\n",
    "            # Every selected-class object that fits the window is listed for this\n",
    "            # crop, with coordinates relative to the window origin.\n",
    "            objects=[transform_box(obox, new_xmin, new_ymin) for obox in labelled\n",
    "                     if overlap(obox, old_width, old_height,new_xmin, new_ymin, new_width, new_height,tol)]\n",
    "\n",
    "            # Register the crop only after the file was written, so a failure\n",
    "            # above cannot leave a half-filled entry behind.\n",
    "            new_annos['imgs'][new_id]={\n",
    "                'id':new_id,\n",
    "                'objects':objects,\n",
    "                'path':os.path.join(cropped_dir,new_img_path),\n",
    "            }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "train_dir=os.path.join(datadir,'train')\n",
    "test_dir=os.path.join(datadir,'test')\n",
    "# An image id is the file name up to its first dot. os.listdir returns entries\n",
    "# in arbitrary order, so the ids are sorted to process images in a stable order.\n",
    "train_id_list=sorted(name.split('.')[0] for name in os.listdir(train_dir))\n",
    "test_id_list=sorted(name.split('.')[0] for name in os.listdir(test_dir))\n",
    "cropped_dir_train='cropped_train'\n",
    "cropped_dir_test='cropped_test'\n",
    "train_new_path=os.path.join(datadir, cropped_dir_train)\n",
    "test_new_path=os.path.join(datadir, cropped_dir_test)\n",
    "# exist_ok=True makes the cell safe to re-run without a separate existence check.\n",
    "os.makedirs(train_new_path, exist_ok=True)\n",
    "os.makedirs(test_new_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Crop every test image. The loop variable is not named `id` so that the\n",
    "# builtin id() is not overwritten in the kernel's global namespace.\n",
    "for img_id in tqdm(test_id_list):\n",
    "    crop_image(img_id,2048,2048,512,512,'_1',cropped_dir=cropped_dir_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 6105/6105 [22:08<00:00,  3.88it/s]\n"
     ]
    }
   ],
   "source": [
    "# Crop every training image. The loop variable is not named `id` so that the\n",
    "# builtin id() is not overwritten in the kernel's global namespace.\n",
    "for img_id in tqdm(train_id_list):\n",
    "    crop_image(img_id,2048,2048,512,512,'_1',cropped_dir=cropped_dir_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def Data_augmentation(num):\n",
    "    \"\"\"Create extra augmented training crops for images that contain rare classes.\n",
    "\n",
    "    Every crop currently registered in new_annos['imgs'] is inspected, except\n",
    "    those stored under `cropped_dir_test`. For each object whose category has\n",
    "    fewer than `num` boxes according to `class_counter`, the original image is\n",
    "    cropped again num // count + 1 times with aug=True, using the notes '_2',\n",
    "    '_3', ... so the new crops get ids of their own.\n",
    "\n",
    "    Parameters:\n",
    "        num: box count below which a category is treated as rare.\n",
    "\n",
    "    Returns None. New crops are written and registered by crop_image.\n",
    "    \"\"\"\n",
    "    # Augmentation rounds already generated per original image. Several crops\n",
    "    # (and several boxes) point at the same original; repeating a round would\n",
    "    # only overwrite the files and entries written by the previous one.\n",
    "    rounds_done={}\n",
    "    # Snapshot of the keys: crop_image adds entries to new_annos['imgs'] while we loop.\n",
    "    keys=list(new_annos['imgs'].keys())\n",
    "    for imgid in tqdm(keys):\n",
    "        entry=new_annos['imgs'][imgid]\n",
    "        # The stored path is cropped_dir joined with the file name; dirname\n",
    "        # recovers the folder whichever separator os.path.join used.\n",
    "        if os.path.dirname(entry['path'])==cropped_dir_test:\n",
    "            continue\n",
    "        # Assumes original image ids contain no underscore - TODO confirm.\n",
    "        org_id=imgid.split('_')[0]\n",
    "        for box in entry['objects']:\n",
    "            count=class_counter[box['category']]\n",
    "            if count>=num:\n",
    "                continue\n",
    "            n=num//count+1\n",
    "            start=rounds_done.get(org_id,0)\n",
    "            for k in range(start, n):\n",
    "                crop_image(org_id, 2048,2048,512,512, note='_'+str(k+2),cropped_dir=cropped_dir_train, aug=True)\n",
    "            rounds_done[org_id]=max(start, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
